import os
import sys

# Switch the working directory to sys.path[0] (normally the directory of the
# script that was launched) so the relative "../record" / "../resource" paths
# used throughout this file resolve the same way regardless of where Python was
# started. Then put "../../" (presumably the repository root — TODO confirm) on
# the import path. This must run before the project-local imports below.
os.chdir(sys.path[0])
sys.path.append("../../")
os.getcwd()  # NOTE(review): the return value is discarded, so this line has no effect

import argparse
import logging
import pickle
import random
import time
from copy import deepcopy
from pprint import pformat

import numpy as np
import torch
import yaml

# from models.resnet import ResNet18, ResNet18Extractor, ResNet34, ResNet34Extractor, ResNet50, ResNet50Extractor
from models.resnet import ResNet18, ResNet34, ResNet50
from selection.utils.selection_dataset import NewMixBackdoorDataset
from utils.aggregate_block.bd_attack_generate import bd_attack_img_trans_generate, bd_attack_label_trans_generate
from utils.aggregate_block.dataset_and_transform_generate import dataset_and_transform_generate
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.aggregate_block.save_path_generate import generate_save_folder
from utils.backdoor_generate_poison_index import generate_poison_index_from_label_transform
from utils.bd_dataset_v2 import dataset_wrapper_with_transform, get_labels, prepro_cls_DatasetBD_v2


def get_argparser():
    """Build the command-line parser for a selection run.

    Only ``--device`` (default ``"cuda:0"``) and ``--selection_yaml_path``
    (default ``"../selection/selection_config.yaml"``) carry defaults; every
    other option defaults to ``None`` when not given on the command line.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, help="cuda|cpu", default="cuda:0")
    parser.add_argument("--selection_yaml_path", type=str, default="../selection/selection_config.yaml")

    # (option name, converter) in registration order; none of these has a default.
    typed_options = (
        ("save_folder_name", str),
        ("random_seed", str),
        ("model", str),
        ("batch_size", int),
        ("lr", float),
        ("epochs", int),
        ("dataset_path", str),
        ("dataset", str),
        ("num_classes", int),
        ("input_height", int),
        ("input_width", int),
        ("input_channel", int),
        ("pratio", float),
        ("attack", str),
        ("attack_target", int),
        ("attack_label_trans", str),
        ("poison_idx_path", str),
    )
    for option_name, converter in typed_options:
        parser.add_argument(f"--{option_name}", type=converter)

    return parser


# Per-dataset metadata: (num_classes, input_height, input_width, input_channel).
_DATASET_SPECS = {
    "mnist": (10, 28, 28, 1),
    "cifar10": (10, 32, 32, 3),
    "cifar100": (100, 32, 32, 3),
    "gtsrb": (43, 32, 32, 3),
    "celeba": (8, 64, 64, 3),
    "tiny": (200, 64, 64, 3),
}


def _resolve_save_path(args):
    """Return the record folder for this run, creating an explicitly named one if needed."""
    save_folder_name = getattr(args, "save_folder_name", None)
    if save_folder_name is None:
        # Covers both "attribute absent" and "left at its argparse default of
        # None" (get_argparser always defines it). A plain `in args` test missed
        # the second case and produced the folder "../record/None".
        return generate_save_folder(
            run_info=("afterwards" if "load_path" in args.__dict__ else "attack")
            + "_"
            + (args.attack if "attack" in args.__dict__ else "prototype"),
            given_load_file_path=args.load_path if "load_path" in args else None,
            all_record_folder_path="../record",
        )
    save_path = "../record/" + save_folder_name
    os.makedirs(save_path, exist_ok=True)
    return save_path


def _attach_run_log_handlers(save_path):
    """Attach a timestamped log file handler and a console handler (INFO) to the root logger."""
    log_formatter = logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
    )
    root_logger = logging.getLogger()
    log_file = save_path + "/" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log"
    for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
        handler.setFormatter(log_formatter)
        handler.setLevel(logging.INFO)
        root_logger.addHandler(handler)
    # The logger level must be <= every handler level, otherwise nothing is emitted.
    root_logger.setLevel(logging.INFO)
    # Silence chatty third-party debug output.
    logging.getLogger("PIL").setLevel(logging.WARNING)
    logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)


def preprocess_args(args):
    """Fill in derived run settings on ``args`` and set up the run's record folder and logging.

    Mutates ``args`` in place:
      * appends ``/<dataset>`` to ``args.dataset_path``;
      * sets ``num_classes``, ``input_height``, ``input_width``, ``input_channel``
        and ``img_size`` (height, width, channel) from the dataset name;
      * sets ``save_path`` to ``../record/<save_folder_name>`` (created if
        missing) or, when no folder name is given (absent or None), to the folder
        chosen by ``generate_save_folder``.

    Side effects: writes ``args.__dict__`` to ``<save_path>/info.pickle`` with
    ``torch.save``. It also adds a file handler and a console handler to the root
    logger; each call adds new handlers, so calling it twice duplicates log output.

    Args:
        args: an object with attributes (e.g. an ``argparse.Namespace``) providing at
            least ``dataset_path`` and ``dataset``.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets
            (subclass of ``Exception``, as raised previously).
    """
    args.dataset_path = f"{args.dataset_path}/{args.dataset}"

    spec = _DATASET_SPECS.get(args.dataset)
    if spec is None:
        raise ValueError("Invalid Dataset")
    args.num_classes, args.input_height, args.input_width, args.input_channel = spec
    args.img_size = (args.input_height, args.input_width, args.input_channel)

    save_path = _resolve_save_path(args)
    args.save_path = save_path

    torch.save(args.__dict__, save_path + "/info.pickle")

    _attach_run_log_handlers(save_path)
    logging.info(pformat(args.__dict__))

    return args


def generate_clean_dataset(args):
    """Load the clean train/test datasets and wrap each with its transforms.

    Returns a 10-tuple: the six values produced by
    ``dataset_and_transform_generate(args)`` (train dataset, train image
    transform, train label transform, test dataset, test image transform, test
    label transform), followed by the wrapped train dataset, the train labels,
    the wrapped test dataset and the test labels.
    """
    generated = dataset_and_transform_generate(args)
    (
        raw_train,
        train_img_tf,
        train_label_tf,
        raw_test,
        test_img_tf,
        test_label_tf,
    ) = generated

    wrapped_train = dataset_wrapper_with_transform(raw_train, train_img_tf, train_label_tf)
    train_labels = get_labels(raw_train)

    wrapped_test = dataset_wrapper_with_transform(raw_test, test_img_tf, test_label_tf)
    test_labels = get_labels(raw_test)

    return (
        raw_train,
        train_img_tf,
        train_label_tf,
        raw_test,
        test_img_tf,
        test_label_tf,
        wrapped_train,
        train_labels,
        wrapped_test,
        test_labels,
    )


def generate_backdoor_dataset(
    args,
    clean_train_dataset_targets,
    clean_test_dataset_targets,
    train_dataset_without_transform,
    test_dataset_without_transform,
    train_img_transform,
    train_label_transform,
    test_img_transform,
    test_label_transform,
    train_poison_index,
    test_poison_index,
    save_folder_path=None,
):
    """Build the poisoned train set and the poisoned (ASR) test set.

    Each split is deep-copied and wrapped in ``prepro_cls_DatasetBD_v2`` with the
    attack's image/label pre-transforms, then wrapped again with the regular
    train/test transforms. The test split is reduced to the samples whose
    ``test_poison_index`` entry equals 1.

    ``clean_train_dataset_targets`` and ``clean_test_dataset_targets`` are
    accepted for interface compatibility but not used here.

    Returns:
        ``(bd_train_dataset_with_transform, bd_test_dataset_with_transform)``.
    """
    bd_img_tf_train, bd_img_tf_test = bd_attack_img_trans_generate(args)
    bd_label_tf = bd_attack_label_trans_generate(args)

    def _poisoned_copy(source_dataset, indicator, bd_img_tf):
        # Deep-copy so the clean source dataset object is left untouched.
        return prepro_cls_DatasetBD_v2(
            deepcopy(source_dataset),
            poison_indicator=indicator,
            bd_image_pre_transform=bd_img_tf,
            bd_label_pre_transform=bd_label_tf,
            save_folder_path=save_folder_path,
        )

    # Train split: keep every sample; poisoned ones are selected by train_poison_index.
    poisoned_train = _poisoned_copy(train_dataset_without_transform, train_poison_index, bd_img_tf_train)
    train_result = dataset_wrapper_with_transform(poisoned_train, train_img_transform, train_label_transform)

    # Test split for ASR: keep only the positions flagged for poisoning.
    poisoned_test = _poisoned_copy(test_dataset_without_transform, test_poison_index, bd_img_tf_test)
    flagged_positions = np.where(test_poison_index == 1)[0]
    poisoned_test.subset(flagged_positions)
    test_result = dataset_wrapper_with_transform(poisoned_test, test_img_transform, test_label_transform)

    return (train_result, test_result)


def generate_new_mix_dataset(clean_train_dataset_with_transform, bd_train_dataset_with_transform):
    """Pair the clean and backdoored train datasets in a ``NewMixBackdoorDataset``."""
    return NewMixBackdoorDataset(
        clean_dataset_with_transform=clean_train_dataset_with_transform,
        bd_dataset_with_transform=bd_train_dataset_with_transform,
    )


def set_trigger(args):
    """Copy attack-specific trigger settings onto the generic attributes the attack code reads.

    Depending on ``args.attack`` (and, for ``badnet``/``sig``, ``args.dataset``):
      * badnet: ``patch_mask_path`` from ``badnet_patch_mask_path``, or the
        built-in 64px trigger image for ``tiny``;
      * badnet_one_pixel: ``patch_mask_path`` from ``badnet_one_pixel_patch_mask_path``;
      * blended / trojanwm: ``attack_trigger_img_path``, ``attack_train_blended_alpha``
        and ``attack_test_blended_alpha`` from the ``<attack>_attack_*`` attributes;
      * ssba: per-dataset replacement-image ``.npy`` paths;
      * sig_clean: ``clean_label = True``;
      * sig on tiny: ``sig_delta = 10``.
    Any other attack leaves ``args`` unchanged.

    Returns:
        The same (mutated) ``args`` object.
    """
    attack = args.attack
    if attack == "badnet":
        if args.dataset == "tiny":
            args.patch_mask_path = '../resource/badnet/trigger_image_64.png'
        else:
            args.patch_mask_path = args.badnet_patch_mask_path
    elif attack == "badnet_one_pixel":
        args.patch_mask_path = args.badnet_one_pixel_patch_mask_path
    elif attack in ("blended", "trojanwm"):
        # Both attacks store the same three settings under an "<attack>_attack_" prefix.
        for setting in ("trigger_img_path", "train_blended_alpha", "test_blended_alpha"):
            setattr(args, f"attack_{setting}", getattr(args, f"{attack}_attack_{setting}"))
    elif attack == "ssba":
        args.attack_train_replace_imgs_path = f"../resource/ssba/{args.dataset}_ssba_train_b1.npy"
        args.attack_test_replace_imgs_path = f"../resource/ssba/{args.dataset}_ssba_test_b1.npy"
    elif attack == "sig_clean":
        args.clean_label = True
    elif attack == "sig" and args.dataset == "tiny":
        args.sig_delta = 10
    return args


def generate_random_poison_idx(
    args, clean_train_dataset_targets, clean_test_dataset_targets, bd_label_transform, pratio, clean_label=False
):
    """Return ``(train_poison_index, test_poison_index)`` for the current attack.

    The train index is loaded with ``np.load`` from ``../record/<args.poison_idx_path>``
    when that path is set. Otherwise it is generated from the train labels with
    ratio ``pratio`` (``clean_label`` is forwarded). The test index is always
    generated from the test labels with ``train=False``.
    """
    preset_index_path = args.poison_idx_path
    if preset_index_path is None:
        train_poison_index = generate_poison_index_from_label_transform(
            clean_train_dataset_targets,
            label_transform=bd_label_transform,
            train=True,
            pratio=pratio,
            p_num=None,
            clean_label=clean_label,
        )
    else:
        # Re-use a previously saved index; the path is relative to ../record.
        train_poison_index = np.load(f"../record/{preset_index_path}")

    test_poison_index = generate_poison_index_from_label_transform(
        clean_test_dataset_targets,
        label_transform=bd_label_transform,
        train=False,
    )
    return train_poison_index, test_poison_index


def distribute_bd_sample(count, num_classes):
    """Split ``count`` backdoor samples across ``num_classes`` classes.

    When ``count >= num_classes`` the samples are spread as evenly as possible:
    every class gets ``count // num_classes`` and the first
    ``count % num_classes`` classes get one extra. Otherwise ``count`` distinct
    classes, chosen with the global ``random`` module, get one sample each and
    the remaining classes get zero.

    Args:
        count: total number of samples. A float (e.g. ``n_train * ratio``) is
            rounded to the nearest integer first, so per-class counts are ints.
        num_classes: number of classes to distribute over.

    Returns:
        A list of ``num_classes`` plain Python ints that sums to ``count``
        (an empty list when ``num_classes <= 0``).

    Raises:
        ValueError: if ``count`` is negative.
    """
    # Float counts previously produced float per-class counts and, in the
    # random branch, a TypeError from random.sample.
    count = int(round(count))
    if count < 0:
        raise ValueError(f"count must be non-negative, got {count}")
    if num_classes <= 0:
        return []

    if count >= num_classes:
        base, extra = divmod(count, num_classes)
        return [base + 1 if cls_idx < extra else base for cls_idx in range(num_classes)]

    # Fewer samples than classes: one sample each for a random subset of classes.
    res = [0] * num_classes
    for cls_idx in random.sample(range(num_classes), count):
        res[cls_idx] = 1
    return res


def save_pickle(data, file_path):
    """Pickle ``data`` to ``file_path`` (overwriting it) using the highest pickle protocol."""
    with open(file_path, mode="wb") as out_file:
        pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)


def load_pickle(file_path):
    """Unpickle and return the object stored in ``file_path``.

    Security: unpickling can execute arbitrary code, so only load files this
    project wrote itself (e.g. via ``save_pickle``), never untrusted input.
    """
    with open(file_path, mode="rb") as in_file:
        return pickle.load(in_file)


def loss_with_mask(all_bd_loss, all_cl_loss, mask):
    """Combine per-sample backdoor and clean losses according to a 0/1 mask.

    Args:
        all_bd_loss: 1-D tensor of per-sample losses, used where ``mask`` is 1.
        all_cl_loss: 1-D tensor of per-sample losses, used where ``mask`` is 0;
            same length, dtype and device as ``all_bd_loss``.
        mask: per-sample indicator (list, numpy array or tensor). It is
            converted to the dtype and device of ``all_bd_loss``, so CUDA and
            float64 losses work.

    Returns:
        ``(all_loss, bd_loss, cl_loss)``:
          * ``all_loss``: masked sum of both losses divided by the number of samples;
          * ``bd_loss``: mean backdoor loss over samples where mask is 1;
          * ``cl_loss``: mean clean loss over samples where mask is 0.
        ``bd_loss`` is NaN when the mask has no 1s, and ``cl_loss`` is NaN when it
        has no 0s (0/0).
    """
    # torch.tensor(mask).float() always produced a CPU float32 copy, which made
    # torch.dot fail for CUDA or float64 losses and warned on tensor input.
    mask = torch.as_tensor(mask, dtype=all_bd_loss.dtype, device=all_bd_loss.device)
    inv_mask = 1 - mask
    bd_sum = torch.dot(mask, all_bd_loss)
    cl_sum = torch.dot(inv_mask, all_cl_loss)
    all_loss = (bd_sum + cl_sum) / len(all_bd_loss)
    bd_loss = bd_sum / torch.sum(mask)
    cl_loss = cl_sum / torch.sum(inv_mask)
    return all_loss, bd_loss, cl_loss


def all_acc(preds: torch.Tensor, labels: torch.Tensor):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    The denominator is ``len(preds)``. If either input is empty, a warning is
    logged and ``None`` is returned instead.
    """
    if not len(preds) or not len(labels):
        logging.warning("zero len array in func all_acc(), return None!")
        return None
    matches = preds.eq(labels)
    num_correct = matches.sum().item()
    return num_correct / len(preds)


if __name__ == "__main__":
    # Manual check: spread 0.8% of a 50k-sample training set over 99 classes.
    # Use an int sample count; passing the raw float product made the
    # per-class counts floats.
    num_bd_samples = round(50000 * 0.008)
    res = distribute_bd_sample(num_bd_samples, 99)
    print(res)
